{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning a communication channel\n",
    "\n",
    "This model shows how to include synaptic plasticity in nengo models using the hPES learning rule. You will implement error-driven learning to learn to compute a simple communication channel. This is done by using an error signal provided by a neural ensemble to modulate the connection between two other ensembles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Setup the environment\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import nengo\n",
    "from nengo.processes import WhiteSignal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model has parameters as described in the book. Note that hPES learning rule is built into Nengo 1.4 as mentioned in the book. The same rule can be implemented in Nengo 2.0 by combining the PES and the BCM rule as shown in the code.  Also, instead of using the 'gate' and 'switch' as described in the book, 'inhibit' population is used which serves the same purpose of turning off the learning by inhibiting the error population. For computing the actual error (which is required only for analysis), direct mode can be used by specifying the neuron_type as 'nengo.Direct()' while creating the 'actual_error' ensemble.\n",
    "\n",
    "\n",
    "When you run the model, you will see that the 'post' population gradually learns to compute the communication channel. In the model, you will inhibit the error population after 15 seconds to turn off learning and you will see that the 'post' population will still track the 'pre' population showing that the model has actually learned the input. \n",
    "\n",
    "The model can also learn other functions by using an appropriate error signal. For example to learn a square function, comment out the lines marked 'Communication Channel' and uncomment the lines marked 'Square' in the code. Run the model again and you will see that the model successfully learns the square function."
   ]
  },
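  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the error-driven part of the rule described above, the sketch below applies a PES-style delta rule to the decoders of a hypothetical population of rectified-linear rate neurons in plain NumPy. It is not Nengo's implementation: the tuning curves, `learning_rate`, and step counts are assumptions chosen for illustration, and the BCM term is omitted. Learning is switched off partway through to mimic inhibiting the error population.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "n_neurons = 50\n",
    "\n",
    "# Hypothetical rectified-linear tuning curves standing in for the 'pre' ensemble\n",
    "encoders = rng.choice([-1.0, 1.0], size=n_neurons)\n",
    "gains = rng.uniform(0.5, 2.0, size=n_neurons)\n",
    "biases = rng.uniform(-1.0, 1.0, size=n_neurons)\n",
    "\n",
    "def activities(x):\n",
    "    # Firing rates of all neurons for a scalar input x\n",
    "    return np.maximum(0.0, gains * encoders * x + biases)\n",
    "\n",
    "decoders = np.zeros(n_neurons)   # decoded output starts at zero\n",
    "learning_rate = 1e-3             # assumed value, not a Nengo default\n",
    "n_steps, learn_until = 8000, 6000\n",
    "\n",
    "abs_err = np.zeros(n_steps)\n",
    "for step in range(n_steps):\n",
    "    x = rng.uniform(-1.0, 1.0)          # stand-in for the white-noise input\n",
    "    a = activities(x)\n",
    "    err = np.dot(decoders, a) - x       # error = decoded output - target\n",
    "    abs_err[step] = abs(err)\n",
    "    if step < learn_until:              # after this the error is gated off\n",
    "        decoders -= learning_rate * err * a\n",
    "\n",
    "early = abs_err[:100].mean()\n",
    "late = abs_err[learn_until:].mean()\n",
    "print('mean |error| early: %.3f, after gating: %.3f' % (early, late))\n",
    "```\n",
    "\n",
    "If the delta rule has converged by the time learning is switched off, the error stays small even though the decoders no longer change, which mirrors 'post' continuing to track 'pre' once the error population is inhibited."
   ]
  },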
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = nengo.Network(label='Learning', seed=7)\n",
    "with model:\n",
    "    # Ensembles to represent populations\n",
    "    pre = nengo.Ensemble(50, dimensions=1)\n",
    "    post = nengo.Ensemble(50, dimensions=1)\n",
    "    error = nengo.Ensemble(100, dimensions=1)\n",
    "    actual_error = nengo.Ensemble(100, dimensions=1, neuron_type=nengo.Direct())\n",
    "    \n",
    "    # Actual Error = pre - post (direct mode)\n",
    "    # nengo.Connection(pre, actual_error, function=lambda x: x**2, transform=-1)  #Square\n",
    "    nengo.Connection(pre, actual_error, transform=-1)                             #Communication Channel\n",
    "    nengo.Connection(post, actual_error, transform=1)\n",
    "     \n",
    "    # Error = pre - post\n",
    "    # nengo.Connection(pre, error, function=lambda x: x**2, transform=-1)         #Square\n",
    "    nengo.Connection(pre, error, transform=-1, synapse=0.02)                      #Communication Channel\n",
    "    nengo.Connection(post, error, transform=1, synapse=0.02)\n",
    "    \n",
    "    # Connecting pre population to post population (communication channel)\n",
    "    conn = nengo.Connection(pre, post, function=lambda x: np.random.random(1),\n",
    "                            solver=nengo.solvers.LstsqL2(weights=True))\n",
    "    \n",
    "    # Adding the learning rule to the connection \n",
    "    conn.learning_rule_type = {'my_pes': nengo.PES(), 'my_bcm': nengo.BCM(learning_rate=1e-10)}\n",
    "    \n",
    "    # Error connections don't impart current\n",
    "    error_conn = nengo.Connection(error, conn.learning_rule['my_pes'])\n",
    "    \n",
    "    # Providing input to the model\n",
    "    input = nengo.Node(WhiteSignal(30, high=10))  # RMS = 0.5 by default\n",
    "    nengo.Connection(input, pre, synapse=0.02)    # Connecting input to the pre ensemble\n",
    "    \n",
    "    # function to inhibit the error population after 25 seconds\n",
    "    def inhib(t):\n",
    "        return 2.0 if t > 15.0 else 0.0\n",
    "    \n",
    "    # Connecting inhibit population to error population\n",
    "    inhibit = nengo.Node(inhib)\n",
    "    nengo.Connection(inhibit, error.neurons, transform=[[-1]] * error.n_neurons, synapse=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Import the nengo_gui visualizer to run and visualize the model.\n",
    "from nengo_gui.ipython import IPythonViz\n",
    "IPythonViz(model, \"ch6-learn.py.cfg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Press the play button in the visualizer to run the simulation. You should see the graphs as shown in the figure below.\n",
    "\n",
    "You will see that the post population initially doesn't track the pre population but eventually it learns the function (which is a communication channel in this case) and starts tracking the post population. After the error is inhibited (i.e., t>15s, shown by value zero in the error plot), the post population continues to track the pre population. The actual_error graph shows that there is significant error between the pre and the post populations at the begining which eventually gets reduced to almost zero as model learns the communication channel. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "Image(filename='ch6-learn.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  },
  "widgets": {
   "state": {},
   "version": "1.1.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
